import os
import random
import json
from tqdm import tqdm
import argparse
import pathlib
from load_aokvqa import load_aokvqa, get_coco_path
import ollama

# Seed Python's global RNG so that any random sampling is reproducible across runs.
# NOTE(review): nothing in this file as shown draws from `random`; the seed may be a
# leftover from an earlier variant of the script — confirm before removing.
random.seed(0)

def get_qwen_result(image_path, prompt, args, model="mistral-small3.1:24b"):
    """Send a single (optionally image-grounded) prompt to a local Ollama chat model.

    Despite the function name, the default model is ``mistral-small3.1:24b``;
    pass ``model=`` to query a different Ollama model tag.

    Args:
        image_path: Path of the image to attach, or ``None``. The image is only
            attached when the path exists on disk; otherwise the prompt is sent
            as a text-only query.
        prompt: Text of the user message.
        args: Parsed CLI namespace providing ``temperature``, ``max_tokens``,
            ``top_p``, ``frequency_penalty`` and ``presence_penalty``.
        model: Ollama model tag to query.

    Returns:
        The content string of the assistant's reply.
    """
    message = {"role": "user", "content": prompt}

    # Attach the image only if it is actually present; a None or missing path
    # degrades to a text-only query instead of raising.
    if image_path is not None and os.path.exists(image_path):
        message["images"] = [image_path]

    options = {
        "temperature": args.temperature,
        # Ollama's generation-length limit is `num_predict`; it does not
        # recognise `max_tokens`, so the CLI value is mapped onto it here.
        "num_predict": args.max_tokens,
        "top_p": args.top_p,
        "frequency_penalty": args.frequency_penalty,
        "presence_penalty": args.presence_penalty,
        # Stop at the first newline so only a single-line answer is returned.
        "stop": ["\n"],
    }

    response = ollama.chat(
        model=model,
        stream=False,
        messages=[message],
        options=options,
    )

    return response['message']['content']

def prompt_element(d, context=None, include_choices=False, answer=False):
    """Render one A-OKVQA item as a "Context / Q / Options / A" prompt fragment.

    Args:
        d: Question record; reads ``d['question']``, plus ``d['choices']`` when
            choices or the answer are rendered, and ``d['correct_choice_idx']``
            when ``answer`` is true.
        context: Optional context text; a ``Context:`` line is emitted whenever
            it is not ``None`` (even if it is an empty string).
        include_choices: If true, add an ``Options:`` line listing the choices
            separated by ", " and terminated by a period.
        answer: If true, append the correct choice after ``A:``; otherwise the
            fragment ends with a bare ``A:`` for the model to complete.

    Returns:
        The assembled prompt fragment (no trailing newline).
    """
    pieces = []
    if context is not None:
        pieces.append(f"Context: {context}\n")
    pieces.append(f"Q: {d['question']}\n")
    if include_choices:
        option_list = ', '.join(d['choices'])
        pieces.append(f"Options: {option_list}.\n")

    last_line = "A:"
    if answer:
        last_line += f" {d['choices'][d['correct_choice_idx']]}"
    pieces.append(last_line)

    return "".join(pieces)

def main():
    """Query a local Ollama model on every question of an A-OKVQA split.

    For each question, a zero-shot prompt is built from the optional
    ``--prefix``, a fixed instruction asking for an answer with a rationale,
    and the question with its answer options. It is sent with the question's
    COCO image via ``get_qwen_result``. The raw model outputs are written to
    ``--out`` as a JSON object mapping ``question_id`` to the output string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    # NOTE: --n, --train-context and --include-choices are still accepted for
    # CLI compatibility but have no effect: no in-context examples are built,
    # and the answer options are always included in the prompt.
    parser.add_argument('--n', type=int, default=10, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    parser.add_argument('--coco-dir',
                        type=str,
                        default='/home/test/yxl/MCoT/data/COCO',
                        dest='coco_dir',
                        help='Root directory of the COCO images.')
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--max_tokens',
                        type=int,
                        default=512,
                        help='The maximum number of tokens allowed for the generated answer.')
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--frequency_penalty', type=float, default=0.0)
    parser.add_argument('--presence_penalty', type=float, default=0.0)
    args = parser.parse_args()

    # Load the split the user asked for; the same split selects the COCO images below.
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Optional per-question context, looked up by question_id. Loaded on its own
    # so that --context works without --train-context.
    context = {}
    if args.context_file is not None:
        with args.context_file as f:
            context = json.load(f)

    instruction = "Please answer the following question in the form \"The correct answer is ,because\" \n\n"

    predictions = {}
    for d in tqdm(eval_set):
        q = d['question_id']

        prompt = args.prompt_prefix + instruction + prompt_element(
            d,
            context=context.get(q, None),
            include_choices=True,
            answer=False,
        )
        image_path = get_coco_path(args.split, d['image_id'], args.coco_dir)
        predictions[q] = get_qwen_result(image_path, prompt, args)

    # Close (and thereby flush) the output file as soon as the dump is done.
    with args.output_file as f:
        json.dump(predictions, f)



# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()